package server

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"gitee.com/yangzx6606/core/pattern"
	log "github.com/golang/glog"
	"github.com/quic-go/quic-go"
	"math/big"
	"sync"
)

// QuicSession wraps one client connection accepted by QuicServer. Its reader
// goroutine (doWork) accepts the client's first stream and publishes incoming
// data via Notify (presumably provided by the embedded AbstractSession —
// NOTE(review): defined outside this file).
type QuicSession struct {
	AbstractSession
	// connection is the underlying QUIC connection. It doubles as the session's
	// key (see GetKey) and is set to nil by Close.
	connection quic.Connection
	// stream is the client-initiated stream accepted in doWork; nil until then
	// and again after Close.
	stream     quic.Stream
	// address is the client's remote address captured when the connection was accepted.
	address    string
	// isOpened keeps the read loop in doWork running while true.
	isOpened   bool
}

// Open flags the session as active and hands the reading work (accepting the
// client's stream and relaying its data) to a new goroutine. Open does not
// block and always returns nil.
func (s *QuicSession) Open() error {
	s.isOpened = true
	reader := s.doWork
	go reader()
	return nil
}

// GetKey returns the session's underlying QUIC connection, which QuicServer
// uses as the key of its session map. After Close it yields nil.
func (s *QuicSession) GetKey() any {
	var key any = s.connection
	return key
}

// Write sends packet.Data on the session's client stream.
//
// It returns an error when packet or its Data is nil. It also returns an
// error when the stream is not available: either doWork has not accepted it
// yet, or the session has been closed. Otherwise it returns whatever the
// stream's Write reports.
func (s *QuicSession) Write(packet *Packet) error {
	if packet == nil || packet.Data == nil {
		return errors.New("参数错误")
	}
	// Read the field once so a concurrent Close cannot nil it between the
	// check and the call.
	stream := s.stream
	if stream == nil {
		return errors.New("数据流未建立")
	}
	_, err := stream.Write(packet.Data)
	return err
}

// GetAddress reports the client's remote address as recorded when the
// connection was accepted.
func (s *QuicSession) GetAddress() string {
	addr := s.address
	return addr
}

// Close stops the read loop and tears down the QUIC connection and its stream.
// Calling it more than once is harmless: once the connection has been cleared,
// later calls only reset the open flag.
func (s *QuicSession) Close() {
	// Clear the flag first so the read loop in doWork exits rather than
	// continuing to read after this call.
	s.isOpened = false
	if s.connection == nil {
		return
	}
	// Errors are ignored: the session is being discarded either way.
	_ = s.connection.CloseWithError(quic.ApplicationErrorCode(0), "close")
	// The stream is only set once doWork has accepted it; closing a nil
	// interface would panic.
	if s.stream != nil {
		_ = s.stream.Close()
	}
	s.connection = nil
	s.stream = nil
}

// doWork is the session's reader goroutine, started by Open.
//
// It accepts the first stream the client opens. It then publishes every chunk
// read from that stream as a PacketTypeData packet, until a read fails or the
// session is closed. When reading ends, or no stream could be accepted, it
// publishes a single PacketTypeConnectLost packet. The one exception is a
// session that was already closed before this goroutine ran; then it exits
// silently.
func (s *QuicSession) doWork() {
	// Use local copies: Close may nil s.connection / s.stream from another
	// goroutine while this loop is blocked in Read.
	conn := s.connection
	if conn == nil {
		return
	}
	stream, err := conn.AcceptStream(context.Background())
	if err != nil {
		log.Error("打开客户端传流失败:", err)
		s.Notify(&Packet{Type: PacketTypeConnectLost, Session: s})
		return
	}
	s.stream = stream

	buffer := make([]byte, 1024*128)
	for s.isOpened {
		n, readErr := stream.Read(buffer)
		// Per the io.Reader contract, bytes returned together with an error
		// (e.g. io.EOF) are valid and must be delivered first.
		if n > 0 {
			// Give observers their own copy; buffer is overwritten by the next Read.
			data := make([]byte, n)
			copy(data, buffer[:n])
			s.Notify(&Packet{
				Type:    PacketTypeData,
				Data:    data,
				Session: s,
			})
		}
		if readErr != nil {
			log.Error("读取客户端数据失败:", readErr)
			break
		}
	}
	_ = stream.Close()
	s.Notify(&Packet{Type: PacketTypeConnectLost, Session: s})
}

// QuicServer listens for QUIC connections and wraps each accepted connection
// in a QuicSession. It relays session packets to its own observers (see
// OnData) through the embedded pattern.AbstractSubject.
type QuicServer struct {
	pattern.AbstractSubject[*Packet]
	// address is the listen address passed to quic.ListenAddr.
	address   string
	// listener is the active QUIC listener, set by Open.
	listener  quic.Listener
	// isOpened reports whether Open has succeeded; it also drives the accept loop.
	isOpened  bool
	// sessions maps each live connection to its session.
	sessions  map[quic.Connection]*QuicSession
	// mutex guards concurrent access to sessions.
	mutex     sync.Mutex
	// protocols is advertised as the TLS NextProtos (ALPN) list.
	protocols []string
}

// NewQuicServer builds a server that will listen on address and advertise the
// given ALPN protocols. The server is idle until Open is called.
func NewQuicServer(address string, protocols []string) *QuicServer {
	srv := new(QuicServer)
	srv.address = address
	srv.protocols = protocols
	srv.sessions = make(map[quic.Connection]*QuicSession)
	return srv
}

// Open creates a self-signed TLS configuration and starts listening on the
// server's address. Connections are accepted on a background goroutine.
//
// It returns an error if the server is already open, if the TLS configuration
// cannot be generated, or if the listener cannot be created.
func (s *QuicServer) Open() error {
	if s.isOpened {
		return errors.New("服务已打开")
	}
	cfg, err := s.generateTLSConfig()
	if err != nil {
		// Previously this error was overwritten by the ListenAddr result.
		return err
	}
	listener, err := quic.ListenAddr(s.address, cfg, nil)
	if err != nil {
		return err
	}
	s.listener = listener
	s.isOpened = true
	go s.accept()
	return nil
}

// Close stops accepting connections and closes every tracked session. It
// clears isOpened, so Open may be called again afterwards.
func (s *QuicServer) Close() {
	if s.isOpened {
		// Clear the flag before closing the listener so the accept loop treats
		// the resulting Accept error as a normal shutdown.
		s.isOpened = false
		_ = s.listener.Close()
	}
	// Swap the map out under the lock, then close sessions without holding it.
	// Each session's reader reports ConnectLost through OnData, which also
	// takes the lock.
	s.mutex.Lock()
	sessions := s.sessions
	s.sessions = make(map[quic.Connection]*QuicSession)
	s.mutex.Unlock()
	for _, session := range sessions {
		session.Close()
	}
}

// Write delivers data.Data to the session referenced by data.Session.
//
// It returns an error if data, its Session, or its payload is missing. It
// returns an error if the session is no longer tracked (for example, it has
// already disconnected or been closed). Otherwise it returns the session's
// Write result.
func (s *QuicServer) Write(data *Packet) error {
	if data == nil || data.Session == nil || len(data.Data) == 0 {
		return errors.New("参数错误")
	}
	// A closed session reports a nil key; comma-ok avoids a panic here.
	key, ok := data.Session.GetKey().(quic.Connection)
	if !ok {
		return errors.New("连接已断开")
	}
	s.mutex.Lock()
	defer s.mutex.Unlock()
	if v, found := s.sessions[key]; found {
		return v.Write(data)
	}
	return errors.New("连接已断开")
}

// accept runs on its own goroutine (started by Open). It turns every incoming
// connection into a QuicSession, registers it, starts it, and announces it
// with a PacketTypeConnected packet. It returns when the listener fails; the
// failure is logged unless the server is being closed.
func (s *QuicServer) accept() {
	// Bind to the listener this loop was started for, in case Open installs a
	// new one after Close.
	listener := s.listener
	for s.isOpened {
		conn, err := listener.Accept(context.Background())
		if err != nil {
			if s.isOpened {
				log.Error("监听客户端连接失败：", err)
			}
			// conn is nil on error; never fall through to use it.
			return
		}
		address := conn.RemoteAddr().String()
		log.Info("客户端连接：", address)

		session := &QuicSession{
			address:    address,
			connection: conn,
		}
		session.AttachObserver(s, session)

		// Register before Open: the session's reader may report ConnectLost
		// immediately, and OnData must find the entry to remove it.
		s.mutex.Lock()
		s.sessions[conn] = session
		s.mutex.Unlock()

		log.Info("开始监听数据")
		if err = session.Open(); err != nil {
			log.Error("打开连接失败：", err)
			s.mutex.Lock()
			delete(s.sessions, conn)
			s.mutex.Unlock()
			continue
		}
		s.Notify(&Packet{
			Type:    PacketTypeConnected,
			Session: session,
		})
	}
}

// OnData is the observer callback that sessions invoke for each packet they
// publish. Every packet is re-published to the server's own observers. For
// PacketTypeConnectLost, the session is also removed from the session map.
func (s *QuicServer) OnData(packet *Packet, sess any) {
	s.Notify(packet)
	if packet.Type != PacketTypeConnectLost || packet.Session == nil {
		return
	}
	// Lock only around the map update: observers notified above may call
	// Write, which takes the same mutex.
	s.mutex.Lock()
	defer s.mutex.Unlock()
	if key, ok := packet.Session.GetKey().(quic.Connection); ok {
		log.Info("删除会话：", key.RemoteAddr().String())
		delete(s.sessions, key)
		return
	}
	// A locally closed session no longer exposes its connection (GetKey
	// yields nil), so locate its entry by identity instead.
	for key, session := range s.sessions {
		if packet.Session == session {
			log.Info("删除会话：", key.RemoteAddr().String())
			delete(s.sessions, key)
			return
		}
	}
}

// generateTLSConfig creates a throw-away self-signed RSA certificate (serial
// number 1, no subject or validity period set). It returns a TLS config that
// serves this certificate and advertises the server's protocols via NextProtos.
// Clients must therefore skip certificate verification or pin this certificate.
func (s *QuicServer) generateTLSConfig() (*tls.Config, error) {
	// 1024-bit RSA is below current minimum recommendations; 2048 is the
	// conventional floor.
	const keyBits = 2048
	key, err := rsa.GenerateKey(rand.Reader, keyBits)
	if err != nil {
		return nil, err
	}
	template := x509.Certificate{SerialNumber: big.NewInt(1)}
	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
	if err != nil {
		return nil, err
	}
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})

	tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		return nil, err
	}
	return &tls.Config{
		Certificates: []tls.Certificate{tlsCert},
		NextProtos:   s.protocols,
	}, nil
}
